# --------------------------------------------------------
# modified from Hora

import numpy as np
import os



from hora.algo.models.models import ActorCritic, ActorCriticMaskCompensator, MLP
from hora.algo.models.running_mean_std import RunningMeanStd

from hora.utils.misc import tprint



from deoxys import config_root
from deoxys.franka_interface import FrankaInterface
from deoxys.utils import YamlConfig
from deoxys.utils.input_utils import input2action
from deoxys.utils.io_devices import SpaceMouse
from deoxys.utils.log_utils import get_deoxys_example_logger
from deoxys.utils.ik_utils import IKWrapper



import torch
import torch.nn as nn

import sys
sys.path.append('../LEAP_Hand_API/python')
from leap_hand_utils.dynamixel_client import *
import leap_hand_utils.leap_hand_utils as lhu
import time

import pytorch_kinematics as pk

try:
    import rospy
    import tf2_ros
except:
    pass

def _obs_allegro2hora(obses):
    """Reorder a 16-joint reading from Allegro finger order to Hora order.

    The input is treated as four consecutive 4-joint blocks
    (index, middle, ring, thumb); the output is laid out as
    (index, thumb, middle, ring) and cast to ``float32``.
    """
    # Source slices listed in the order they appear in the output.
    finger_slices = (
        slice(0, 4),    # index
        slice(12, 16),  # thumb
        slice(4, 8),    # middle
        slice(8, 12),   # ring
    )
    reordered = np.concatenate([obses[block] for block in finger_slices])
    return reordered.astype(np.float32)


def _action_hora2allegro(actions):
    """Reorder a 16-joint action from Hora finger order to Allegro order.

    Returns a copy of ``actions`` in which joints 0-3 are untouched and the
    remaining three 4-joint blocks are rotated: 8-11 -> 4-7, 4-7 -> 12-15 and
    12-15 -> 8-11. The input is not modified.
    """
    # (destination joints, source joints) for every block that moves.
    block_moves = (
        ([4, 5, 6, 7], [8, 9, 10, 11]),
        ([12, 13, 14, 15], [4, 5, 6, 7]),
        ([8, 9, 10, 11], [12, 13, 14, 15]),
    )
    remapped = actions.copy()
    for dst_joints, src_joints in block_moves:
        # Always read from the untouched input so the moves do not interfere.
        remapped[dst_joints] = actions[src_joints]
    return remapped



def quat_conjugate(a):
    """Return the conjugate of quaternion(s) stored as (x, y, z, w).

    The first three components are negated and the last is kept; the result
    has the same shape as ``a``.
    """
    original_shape = a.shape
    flat = a.reshape(-1, 4)
    vector_part = flat[:, :3]
    scalar_part = flat[:, -1:]
    conjugated = torch.cat((-vector_part, scalar_part), dim=-1)
    return conjugated.view(original_shape)


def normalize(x, eps: float = 1e-9):
    """Scale ``x`` to unit L2 norm along its last dimension.

    The norm is clamped from below by ``eps`` so that zero vectors are
    returned unchanged (as zeros) instead of producing NaNs.
    """
    length = x.norm(p=2, dim=-1)
    safe_length = length.clamp(min=eps, max=None)
    return x / safe_length.unsqueeze(-1)

def quat_unit(a):
    """Return quaternion(s) ``a`` rescaled to unit length via ``normalize``."""
    unit_quat = normalize(a)
    return unit_quat


def quat_from_angle_axis(angle, axis):
    """Build (x, y, z, w) quaternion(s) from rotation angle(s) and axis/axes.

    ``axis`` is normalised first; the vector part is ``axis * sin(angle / 2)``
    and the scalar part ``cos(angle / 2)``. The result is re-normalised
    before being returned.
    """
    half_angle = (angle / 2).unsqueeze(-1)
    unit_axis = normalize(axis)
    components = [unit_axis * half_angle.sin(), half_angle.cos()]
    return quat_unit(torch.cat(components, dim=-1))



def quat_mul(a, b):
    """Multiply quaternions ``a * b`` stored as (x, y, z, w).

    Both inputs must have the same shape with a trailing dimension of 4; the
    product is returned with that same shape. The product is evaluated with
    shared intermediate terms rather than the textbook 16-multiply expansion.
    """
    assert a.shape == b.shape
    out_shape = a.shape
    ax, ay, az, aw = a.reshape(-1, 4).unbind(dim=-1)
    bx, by, bz, bw = b.reshape(-1, 4).unbind(dim=-1)

    # Shared sub-products reused by all four output components.
    term_w = (az + ax) * (bx + by)
    term_y = (aw - ay) * (bw + bz)
    term_z = (aw + ay) * (bw - bz)
    term_x = term_w + term_y + term_z
    half_sum = 0.5 * (term_x + (az - ax) * (bx - by))

    out_w = half_sum - term_w + (az - ay) * (by - bz)
    out_x = half_sum - term_x + (ax + aw) * (bx + bw)
    out_y = half_sum - term_y + (aw - ax) * (by + bz)
    out_z = half_sum - term_z + (az + ay) * (bw - bx)

    product = torch.stack([out_x, out_y, out_z, out_w], dim=-1)
    return product.view(out_shape)

def copysign(a, b):
    # type: (float, Tensor) -> Tensor
    """Return ``|a|`` carrying the sign of each element of ``b``.

    ``a`` is turned into a float tensor on ``b``'s device and repeated
    ``b.shape[0]`` times. Because ``torch.sign`` is used, entries where ``b``
    is exactly zero yield ``0`` rather than ``+|a|``.
    """
    magnitude = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0])
    magnitude = torch.abs(magnitude)
    return magnitude * torch.sign(b)


def get_euler_xyz(q):
    """Convert a batch of (x, y, z, w) quaternions to roll/pitch/yaw angles.

    ``q`` is indexed as ``q[:, i]``, i.e. it is expected to be 2-D with the
    quaternion components in the last dimension. Returns a
    ``(roll, pitch, yaw)`` tuple of tensors, each wrapped into ``[0, 2*pi)``.
    """
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]

    # Roll: rotation about the x axis.
    sinr_cosp = 2.0 * (w * x + y * z)
    cosr_cosp = w * w - x * x - y * y + z * z
    roll = torch.atan2(sinr_cosp, cosr_cosp)

    # Pitch: rotation about the y axis; saturate at +/- pi/2 when |sin| >= 1
    # so asin is never relied upon outside its domain.
    sinp = 2.0 * (w * y - z * x)
    saturated_pitch = copysign(np.pi / 2.0, sinp)
    pitch = torch.where(torch.abs(sinp) >= 1, saturated_pitch, torch.asin(sinp))

    # Yaw: rotation about the z axis.
    siny_cosp = 2.0 * (w * z + x * y)
    cosy_cosp = w * w + x * x - y * y - z * z
    yaw = torch.atan2(siny_cosp, cosy_cosp)

    full_turn = 2 * np.pi
    return roll % full_turn, pitch % full_turn, yaw % full_turn


def compute_euler_diff_from_quats(quat1, quat2):
    """Return the XYZ Euler angles of the rotation taking ``quat1`` to ``quat2``.

    Both inputs are ``bsz x 4`` quaternions in (x, y, z, w) order. The
    relative rotation ``quat2 * conj(quat1)`` is converted with
    ``get_euler_xyz`` and the three angles are stacked into a ``bsz x 3``
    tensor.
    """
    relative_rot = quat_mul(quat2, quat_conjugate(quat1))
    roll, pitch, yaw = get_euler_xyz(relative_rot)
    return torch.stack((roll, pitch, yaw), dim=-1)



sys.path.append('../IsaacGymEnvs2')




from isaacgymenvs.ddim.runners.diffusion_controlseq import Diffusion as DiffusionControlSeq
from isaacgymenvs.ddim.models.diffusion_controlseq import WorldModel
from isaacgymenvs.ddim.models.ema import EMAHelper
import yaml
from isaacgymenvs.ddim.main import dict2namespace
import pickle

class HardwarePlayerInvdynV1(object):
    def __init__(self, config):
        """Configure the player, build the inverse-dynamics model and allocate buffers.

        Almost every setting is hard-coded in the body. Several attributes are
        assigned twice in a row; in those cases the later assignment is the one
        that takes effect.

        Args:
            config: Rebound to the parsed ``controlseq.yml`` contents a few
                lines into the body, so the value passed by the caller is
                never read.

        Note:
            Construction has side effects: it opens YAML files through
            relative paths, calls ``np.load`` on the replay file, creates CUDA
            tensors and initialises a ROS node via ``_init_tf_listener``.
        """
        self.action_scale = 1 / 24
        self.actions_num = 16
        self.device = 'cuda'

        self.model_arch = 'resmlp'
        # self.model_arch = 'resmlp_moe'
        
        # Load the inverse-dynamics config from a path relative to the working directory.
        # NOTE: this rebinds the ``config`` argument; the caller's value is discarded.
        self.invdyn_v2_config_path = 'controlseq.yml'
        with open(os.path.join("../IsaacGymEnvs2/isaacgymenvs/ddim/configs", self.invdyn_v2_config_path), "r") as f:
            config = yaml.safe_load(f)
        invdyn_config = dict2namespace(config)
        invdyn_config.device = self.device
        invdyn_config.invdyn.model_arch = self.model_arch
        invdyn_config.invdyn.res_blocks = 2  # overwritten with 5 further down
        invdyn_config.invdyn.pred_extrin = False
        
        self.history_length = 10
        self.w_hand_root_ornt = False
        # self.w_hand_root_ornt = True
        
        
        # Base hand-root orientation as an (x, y, z, w) quaternion; it is
        # left-multiplied by the wrist rotation computed below.
        self.hand_root_tsr = torch.tensor([0.5, 0.5, -0.5, 0.5], device=self.device).float()
        
        
        hand_wrist_ornt = 'palm_down'

        
        y_unit_tensor = torch.tensor([0, 1, 0], device=self.device).float()
        z_unit_tensor = torch.tensor([0, 0, 1], device=self.device).float()
        x_unit_tensor = torch.tensor([1, 0, 0], device=self.device).float()
        
        # Per-axis rotation amounts, expressed as multiples of pi.
        # With 'palm_down' all three are zero, i.e. no extra rotation is applied.
        if hand_wrist_ornt == 'palm_down':
            x_mult_val = torch.tensor([0.0], device=self.device).float()
            y_mult_val = torch.tensor([0.0], device=self.device).float()
            z_mult_val = torch.tensor([0.0], device=self.device).float()
        elif hand_wrist_ornt == 'thumb_up':
            x_mult_val = torch.tensor([0.0], device=self.device).float()
            y_mult_val = torch.tensor([1.5], device=self.device).float()
            z_mult_val = torch.tensor([0.0], device=self.device).float()
        elif hand_wrist_ornt == 'thumb_down':
            x_mult_val = torch.tensor([1.0], device=self.device).float()
            y_mult_val = torch.tensor([1.5], device=self.device).float()
            z_mult_val = torch.tensor([0.0], device=self.device).float()
        elif hand_wrist_ornt == 'base_up':
            x_mult_val = torch.tensor([0.5], device=self.device).float()
            y_mult_val = torch.tensor([1.5], device=self.device).float()
            z_mult_val = torch.tensor([0.0], device=self.device).float()
        elif hand_wrist_ornt == 'base_down':
            x_mult_val = torch.tensor([-0.5], device=self.device).float()
            y_mult_val = torch.tensor([1.5], device=self.device).float()
            z_mult_val = torch.tensor([0.0], device=self.device).float()
        else:
            raise ValueError(f"Invalid hand wrist orientation: {hand_wrist_ornt}")
        
        # Compose the wrist rotation as x * y * z and apply it to the base orientation.
        z_rot_tensor = quat_from_angle_axis(z_mult_val * np.pi, z_unit_tensor.unsqueeze(0)) # .squeeze(0)
        y_rot_tensor = quat_from_angle_axis(y_mult_val * np.pi, y_unit_tensor.unsqueeze(0)) # .squeeze(0)
        x_rot_tensor = quat_from_angle_axis(x_mult_val * np.pi, x_unit_tensor.unsqueeze(0)) # .squeeze(0)
        hand_rot_quat = quat_mul(quat_mul(x_rot_tensor, y_rot_tensor), z_rot_tensor)
        self.hand_root_tsr = quat_mul(hand_rot_quat, self.hand_root_tsr.unsqueeze(0)).squeeze(0)

        
        invdyn_config.invdyn.w_hand_root_ornt = self.w_hand_root_ornt
        invdyn_config.invdyn.history_length = self.history_length #  10
        invdyn_config.invdyn.future_length = 2
        invdyn_config.invdyn.res_blocks = 5  # final value (replaces the 2 set above)
        
        # Mirror the config values back onto the instance.
        self.history_length = invdyn_config.invdyn.history_length
        self.future_length = invdyn_config.invdyn.future_length
        
        self.invdyn_pred_extrin = True  # reset to False a few lines below
        self.mask_obj_motion = True
        self.hist_extrin_length = 30
        
        self.normalize_input = False
        self.normalize_output = False
        
        self.invdyn_pred_extrin = False  # effective value
        
        # cond on obj motion; cond on hand motion
        # Plain attribute holder passed as the first argument to
        # DiffusionControlSeq further down.
        class dummy:
            def __init__(self):
                self.log_path = ''
                self.sample_type = 'generalized'
                self.skip_type = 'uniform'
                self.timesteps = 50
                self.eta = 0
                self.model_type = 'invdyn'
                # self.optimize_via_fingertip_pos 
        
        
        invdyn_args = dummy()
        self.hist_context_length = 0
        
        self.invdyn_v2_log_path = '' 
        
        
        self.tune_bc_via_compensator_model = False
        
        
        
        invdyn_config.invdyn.hist_context_length = self.hist_context_length
        
        # Only 'leap' is handled when the joint limits are set near the end of this method.
        self.hand_type = 'leap'
        
        
        
        self.replay_wave_type = 'policy'
        
        
        ### NOTE: test openloop replay ###
        self.test_openloop_replay = False
        # self.test_openloop_replay = True
        
        self.test_with_sim_wm_for_prediction = False
        # self.test_with_sim_wm_for_prediction = True
        
        self.test_with_real_wm_for_prediction = False
        # self.test_with_real_wm_for_prediction = True
        
        
        self.test_exp_idx = 0
        self.init_test_exp_idx = 0 
        
        
        self.openloop_replay_interp_coef = 0.0

        
        self.openloop_replay_add_noise = False
        self.openloop_replay_add_noise = True  # effective value
        
        self.openloop_replay_noise_scale = 1/ 100 

        # NOTE(review): empty path; the np.load call below presumably needs a
        # real file here -- set before running.
        replay_experiences_fn = ''
        
        self.openloop_replay_sv_folder = f""
        
        
        
        
        # NOTE(review): '' is not None, so the `is not None` checks in
        # _inference_action_compensator_policy take the replay branch -- confirm intended.
        self.test_openloop_replay_test_compensator = ''
        
        
        # The replay file must hold a pickled dict with 'shadow_hand_dof_tars'
        # and 'shadow_hand_dof_pos' arrays. allow_pickle=True executes pickled
        # content, so only trusted files should be loaded here.
        replay_experiences = np.load(replay_experiences_fn, allow_pickle=True).item()
        replay_actions = replay_experiences['shadow_hand_dof_tars']
        self.replay_actions = torch.from_numpy(replay_actions).float().to(self.device) 
        self.replay_states = torch.from_numpy(replay_experiences['shadow_hand_dof_pos']).float().to(self.device)
        
        # When a single target joint is selected (>= 0), every other joint is
        # zeroed in the replay data. -1 disables this.
        self.target_joint_idx = -1
        if self.target_joint_idx >= 0:
            zero_out_jts_idxes = [_ for _ in range(0, self.replay_actions.size(-1)) if _ != self.target_joint_idx]
            zero_out_jts_idxes = torch.tensor(zero_out_jts_idxes, device=self.device).long()
            self.replay_actions[..., zero_out_jts_idxes] = 0.0
            self.replay_states[..., zero_out_jts_idxes] = 0.0
        
        self.rolling_out_policy = False
        # self.rolling_out_policy = True
        # Button state; refreshed by parse_quality_selection_button.
        self.X_pressed = False
        self.Y_pressed = False
        # Starts the ROS node and TF listener (requires rospy / tf2_ros).
        self._init_tf_listener()
        
        self.cat_rolling_out_dict_sv_folder = ""
        self.cat_rolling_out_dict_sv_folder_bad = self.cat_rolling_out_dict_sv_folder + "_bad"
        ### NOTE: test rolling out policy ###
        
        
        ### NOTE: set the preset grasp data fn ###
        self.preset_grasp_data_fn = None
        self.preset_grasp_data_fn = "cache/leap_down_init0d38_cuboid_default_0_grasp_1k_s08.npy"  # effective value
        ### NOTE: set the preset grasp data fn ###
        
        
        # Delta-action ("compensator") switches. With use_delta_action_model
        # False the whole `if self.use_delta_action_model:` block is skipped.
        self.use_delta_action_model = False
        # self.use_delta_action_model = True
        # 
        self.use_delta_action_model_policy = False
        # self.use_delta_action_model_policy = True
        
        self.use_joint_delta_action_model_policy = False
        # self.use_joint_delta_action_model_policy = True
        
        self.hierarchical_compensator = False
        # self.hierarchical_compensator = True
        
        self.delta_action_model_using_abs_target = False
        self.delta_action_model_using_abs_target = True  # effective value
        
        self.action_compensator_w_full_hand = False
        self.action_compensator_w_full_hand = True  # effective value
        
        self.multi_action_compensator_w_full_hand = False
        # self.multi_action_compensator_w_full_hand = True
        
        self.first_level_scale = 1. / 24
        
        self.action_compensator_scale = 1.0 / 24
        # self.action_compensator_scale = 1.0 / 48
        # self.action_compensator_scale = 1.0 / 480
        # self.action_compensator_scale = 1.0 / 960
        
        if self.use_delta_action_model:
            from isaacgymenvs.ddim.models.diffusion_controlseq import WorldModelDeltaActions
            
            self.finger_idx = -1
            
            self.joint_idx = 8
            
            
            self.hist_context_length = 0
            
            
            self.wm_history_length = 1
            self.wm_history_length = 2  # effective value
            
            self.hist_context_finger_idx = -1
            
            # Joint indices used for history context: the four joints of one
            # finger when a finger is selected, otherwise all 16 joints.
            if self.hist_context_finger_idx >= 0:
                hist_context_finger_joint_idxes = [ _ for _ in range(self.hist_context_finger_idx * 4, (self.hist_context_finger_idx + 1 ) * 4) ]
                hist_context_finger_joint_idxes = torch.tensor(hist_context_finger_joint_idxes, dtype=torch.long).to(self.device)
                self.hist_context_finger_joint_idxes = hist_context_finger_joint_idxes
            else:
                hist_context_finger_joint_idxes = [ _ for _ in range(0, 16) ]
                hist_context_finger_joint_idxes = torch.tensor(hist_context_finger_joint_idxes, dtype=torch.long).to(self.device)
                self.hist_context_finger_joint_idxes = hist_context_finger_joint_idxes
            
            invdyn_config.invdyn.wm_history_length = self.wm_history_length
            invdyn_config.invdyn.hist_context_length = self.hist_context_length
            
             
            if self.action_compensator_w_full_hand:
                # Separate config namespace for the full-hand delta-action model,
                # parsed from the same YAML file as the main model.
                invdyn_v2_config_path = 'controlseq.yml'
                with open(os.path.join("../IsaacGymEnvs2/isaacgymenvs/ddim/configs", invdyn_v2_config_path), "r") as f:
                    config = yaml.safe_load(f)
                compensator_invdyn_config = dict2namespace(config)
                self.compensator_history_length = 2
                self.compensator_history_length = 10  # effective value
                self.hist_context_length = 0
                compensator_invdyn_config.invdyn.history_length = 10
                compensator_invdyn_config.invdyn.history_obs_dim = 32
                compensator_invdyn_config.invdyn.res_blocks = 5
                compensator_invdyn_config.invdyn.model_arch = 'resmlp'
                compensator_invdyn_config.device = 'cuda'
                compensator_invdyn_config.invdyn.pred_extrin = False
                
                compensator_invdyn_config.invdyn.finger_idx = -1
                compensator_invdyn_config.invdyn.joint_idx = -1
                compensator_invdyn_config.invdyn.wm_history_length = self.compensator_history_length #  self.wm_history_length
                compensator_invdyn_config.invdyn.hist_context_length = self.hist_context_length + 0 # full hand 
                
                self.action_compensator_scale = 0.2
                self.action_compensator_scale = 0.05
                
                self.action_compensator_scale = 1/24.  # effective value
                
                self.delta_action_model_full_hand = WorldModelDeltaActions(compensator_invdyn_config).cuda()
                self.delta_action_model_full_hand.eval()
                
                
                # NOTE(review): empty checkpoint path; torch.load in the else
                # branch below presumably needs a real file -- set before enabling.
                self.delta_action_model_full_hand_ckpt_fn = '' 
                
                
                
                # Checkpoints whose path contains 'logs/cs' are indexed as a
                # sequence: element 0 holds the model weights and the last
                # element the EMA state. Other checkpoints are loaded as a
                # plain state dict.
                if 'logs/cs' in self.delta_action_model_full_hand_ckpt_fn:
                    self.action_compensator_scale =  1/24.
                    self.delta_action_model_full_hand.load_state_dict(torch.load(self.delta_action_model_full_hand_ckpt_fn)[0])
                    if invdyn_config.model.ema:
                        ema_helper = EMAHelper(mu=invdyn_config.model.ema_rate)
                        ema_helper.register(self.delta_action_model_full_hand)
                        # ema_helper.load_state_dict(torch.load(self.delta_action_model_full_hand_ckpt_fn)[-1])
                        
                        # Drop the first dotted component from every key that
                        # contains 'module' before loading the EMA state.
                        loaded_state_dict = torch.load(self.delta_action_model_full_hand_ckpt_fn)[-1]
                        cleaned_state_dict = {}
                        for key in loaded_state_dict:
                            if 'module' in key:
                                modified_key = ".".join(key.split('.')[1:])
                            else:
                                modified_key = key
                            cleaned_state_dict[modified_key] = loaded_state_dict[key]
                        ema_helper.load_state_dict(cleaned_state_dict)
                        
                        ema_helper.ema(self.delta_action_model_full_hand)
                else:
                    self.delta_action_model_full_hand.load_state_dict(torch.load(self.delta_action_model_full_hand_ckpt_fn))
                
                self.delta_action_model_full_hand.eval()
                
                # All 16 joints are compensated; the mask is 1.0 at those indices.
                self.action_compensator_target_joint_idxes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

                self.action_compensator_target_joint_idxes = torch.tensor(self.action_compensator_target_joint_idxes, dtype=torch.long).cuda()
                self.action_compensator_joint_mask = torch.zeros((16,), dtype=torch.float32).cuda()
                self.action_compensator_joint_mask[self.action_compensator_target_joint_idxes] = 1.0
            
            # Optional policy-based compensators. _load_action_compensator_policy
            # is not defined in the visible part of this file.
            if self.use_delta_action_model_policy:
                if self.use_joint_delta_action_model_policy:
                    if self.hierarchical_compensator:
                        self._load_action_compensator_policy()
                    self._load_joint_action_compensator_policy()
                else:
                    self._load_action_compensator_policy()
            self.compensated_res_abs_running_mean_std = RunningMeanStd((16, )).to(self.device)
            self.compensated_res_abs_running_mean_std.train()
        
        invdyn_args.log_path = self.invdyn_v2_log_path
        
        
        self.invdyn_obs_type = invdyn_config.invdyn.obs_type
        self.invdyn_future_type = 'obj_motion'
        invdyn_config.invdyn.future_ref_dim = 3
        invdyn_config.invdyn.pred_extrin = self.invdyn_pred_extrin
        
        # Build the inverse-dynamics model; init_models is called with ckpt_fn=None.
        self.invdynmodel = DiffusionControlSeq(invdyn_args, invdyn_config) 
        ckpt_fn = None
        self.invdynmodel.init_models(ckpt_fn=ckpt_fn)
        
        
        
        ##### Initialize the object state predictor #####
        self.obj_state_predictor = False

        # Rolling history buffers, one row per step: 16 joint values per row
        # for the hand buffers and 7 values per row for the object pose.
        self.history_qpos_buf = torch.zeros((invdyn_config.invdyn.history_length, 16), dtype=torch.float32).to(self.device)
        self.history_qtars_buf = torch.zeros(( invdyn_config.invdyn.history_length, 16), dtype=torch.float32).to(self.device)
        self.extrin_history_qpos_buf = torch.zeros((self.hist_extrin_length, 16), dtype=torch.float32).to(self.device)
        self.extrin_history_qtars_buf = torch.zeros((self.hist_extrin_length, 16), dtype=torch.float32).to(self.device)
        self.history_obj_pose_buf = torch.zeros((invdyn_config.invdyn.history_length, 7), dtype=torch.float32).to(self.device) 
        
        # Context buffers are only allocated when a context length is configured.
        if self.hist_context_length > 0:
            self.hist_context_qpos_buf = torch.zeros((self.hist_context_length, 16), dtype=torch.float32).to(self.device)
            self.hist_context_qtars_buf = torch.zeros((self.hist_context_length, 16), dtype=torch.float32).to(self.device) 
        
        
        
        self.obj_motion_format = 'rot_axis'
        
        # NOTE(review): obj_motion_scale is assigned but not read in the
        # visible part of this method.
        obj_motion_scale = 0.07
        obj_motion_scale = 0.05
        # obj_motion_scale = 0.03
        
        
        # Desired future object motion: the unit vector (0, -1, 0) repeated
        # for every future step.
        if self.obj_motion_format == 'rot_axis':
            self.future_obj_motion_buf = torch.tensor(
                [.0, -1.0, 0.0], dtype=torch.float32
            ).unsqueeze(0).repeat(self.future_length, 1).contiguous().cuda()
            self.future_obj_motion_buf = self.future_obj_motion_buf / self.future_obj_motion_buf.norm(dim=-1, p=2, keepdim=True)

        # 7 values followed by 16 values; presumably arm joints then hand
        # joints, going by the attribute name -- TODO confirm.
        self.canonical_arm_with_hand_state = [0, -0.7853, 0, -2.35539, 0, 1.57,0.785] + [
            1.244, 0.082, 0.265, 0.298, 1.163, 1.104, 0.953, -0.138,
            1.096, 0.005, 0.080, 0.150, 1.337, 0.029, 0.285, 0.317,
        ]
        self.canonical_arm_with_hand_state = np.array(self.canonical_arm_with_hand_state).astype(np.float32)
        
        # Per-joint lower/upper limits (16 values each) consumed by
        # _unscale/_scale. The attributes keep the "allegro" prefix even
        # though they are set for the 'leap' hand type.
        if self.hand_type == 'leap':
            self.allegro_dof_lower = torch.from_numpy(np.array([
                -0.3140, -1.0470, -0.5060, -0.3660, -0.3490, -0.4700, -1.2000, -1.3400,
                -0.3140, -1.0470, -0.5060, -0.3660, -0.3140, -1.0470, -0.5060, -0.3660            
            ])).to(self.device)
            self.allegro_dof_upper = torch.from_numpy(np.array([
                2.2300, 1.0470, 1.8850, 2.0420, 2.0940, 2.4430, 1.9000, 1.8800, 2.2300,
                1.0470, 1.8850, 2.0420, 2.2300, 1.0470, 1.8850, 2.0420
            ])).to(self.device)
        else:
            raise NotImplementedError(f"Unsupported hand type: {self.hand_type}")
        
        # Joint permutation applied by _action_hora2leap; its argsort (the
        # inverse permutation) is applied by _obs_leap2hora.
        joint_idxes_ordering = [1, 0, 2, 3, 9, 8, 10, 11, 13, 12, 14, 15, 4, 5, 6, 7]
        joint_idxes_ordering = np.array(joint_idxes_ordering).astype(np.int32)
        self.joint_idxes_ordering = joint_idxes_ordering
        joint_idxes_inversed_ordering = np.argsort(joint_idxes_ordering)
        self.joint_idxes_inversed_ordering = joint_idxes_inversed_ordering

    
    def _load_joint_action_compensator_policy(self, ):
        """Build the joint-level action compensator and restore it from disk.

        Creates an ``ActorCriticMaskCompensator`` that outputs corrections for
        joints 8, 12 and 13 from a 48-dimensional observation, together with
        its observation normaliser, loads both from the hard-coded checkpoint
        and puts them in eval mode. Also overwrites ``self.wm_history_length``
        with 1.
        """
        # Bookkeeping attributes; all four input/output selectors are set to -1.
        self.compensator_finger_idx = 2
        self.wm_history_length = 1
        self.action_compensator_input_joint_idx = self.action_compensator_input_finger_idx = -1
        self.action_compensator_output_joint_idx = self.action_compensator_output_finger_idx = -1

        # Joints whose targets the compensator is allowed to adjust.
        output_joints = torch.tensor([8, 12, 13], dtype=torch.long).cuda()
        self.compensator_output_joint_idxes = output_joints

        obs_shape = (32 + 16, )
        policy_net_config = dict(
            actions_num=output_joints.shape[0],
            input_shape=obs_shape,
            actor_units=[512, 256, 128],
            priv_mlp_units=[256, 128, 8],
            priv_info=False,
            proprio_adapt=False,
            priv_info_dim=9,
            target_joint_idx_tensor=output_joints,
        )

        self.compensator = ActorCriticMaskCompensator(policy_net_config)
        self.compensator.to(self.device)
        self.compensator_running_mean_std = RunningMeanStd(obs_shape).to(self.device)

        # Restore network weights and normaliser statistics from one checkpoint.
        self.joint_action_compensator_ckpt_fn = 'outputs/LeapHora/debug_true_cuboidthin_wmfromcubhoidthincylinderdefault_compensator_jt81213/stage1_nn/best.pth'
        checkpoint = torch.load(self.joint_action_compensator_ckpt_fn)
        self.compensator.load_state_dict(checkpoint['model'])
        self.compensator_running_mean_std.load_state_dict(checkpoint['running_mean_std'])

        for module in (self.compensator, self.compensator_running_mean_std):
            module.eval()
    
    
    def _inference_joint_action_compensator_policy(self, cur_ori_targets):
        """Apply the joint-level compensator to the current joint targets.

        The observation is the last ``wm_history_length`` rows of joint
        positions (mapped through ``_unscale``) and joint targets, flattened,
        followed by ``cur_ori_targets``. The policy output is clamped to
        [-1, 1], scaled by ``action_compensator_scale`` and added to the
        targets at ``compensator_output_joint_idxes``.

        Returns a new tensor; ``cur_ori_targets`` is not modified.
        """
        window = self.wm_history_length
        recent_qpos = self.history_qpos_buf[-window:, :]
        recent_qtars = self.history_qtars_buf[-window:, :]

        normalized_qpos = self._unscale(
            recent_qpos, self.allegro_dof_lower, self.allegro_dof_upper
        ).float()
        history = torch.cat([normalized_qpos, recent_qtars], dim=-1)
        # Flatten the whole window into a single row vector.
        history = history.contiguous().view(-1).contiguous().unsqueeze(0)
        policy_obs = torch.cat([history, cur_ori_targets], dim=-1)

        policy_input = {
            'obs': self.compensator_running_mean_std(policy_obs).float()
        }
        raw_delta = self.compensator.act_inference(policy_input)
        bound = torch.ones_like(raw_delta)
        delta = torch.clamp(raw_delta, -1.0 * bound, bound)

        out_joints = self.compensator_output_joint_idxes
        compensated = cur_ori_targets.clone()
        compensated[..., out_joints] = (
            cur_ori_targets[..., out_joints] + self.action_compensator_scale * delta
        )
        return compensated
    
    
    def _inference_action_compensator_policy(self, cur_ori_targets, action_scale=None):
        """Apply per-finger and then per-joint compensators to the joint targets.

        For every entry of ``finger_idx_to_compensator`` and then of
        ``joint_idx_to_compensator`` an observation is built from the recent
        joint-position history (mapped through ``_unscale``), the recent
        joint-target history and the current target of the joints that
        compensator handles. The policy output is clamped to [-1, 1], scaled
        and added to the original target of those joints.

        In the per-joint loop, whenever
        ``self.test_openloop_replay_test_compensator`` is not ``None`` the
        history and target come from the recorded replay tensors instead, and
        the replay cursor (step / trajectory index) is advanced once per
        joint compensator.

        Args:
            cur_ori_targets: Current joint targets; not modified.
            action_scale: Scale applied to the clamped policy output; falls
                back to ``self.action_compensator_scale`` when ``None``.

        Returns:
            A new tensor holding the compensated targets.
        """
        compensated_actions = cur_ori_targets.clone()

        if action_scale is None:
            action_scale = self.action_compensator_scale

        # ---- finger-level compensators ----
        for finger_id in self.finger_idx_to_compensator:
            finger_joints = self.finger_idx_to_compensator_input_joint_idxes[finger_id]
            qpos_hist = self.history_qpos_buf[-self.wm_history_length:, finger_joints]
            qtars_hist = self.history_qtars_buf[-self.wm_history_length:, finger_joints]

            normalized_qpos = self._unscale(
                qpos_hist,
                self.allegro_dof_lower[finger_joints],
                self.allegro_dof_upper[finger_joints],
            ).float()
            history = torch.cat([normalized_qpos, qtars_hist], dim=-1)
            history = history.contiguous().view(-1).contiguous().unsqueeze(0)
            policy_obs = torch.cat([history, cur_ori_targets[..., finger_joints]], dim=-1)

            policy_input = {
                'obs': self.finger_idx_to_compensator_running_mean_std[finger_id](policy_obs).float()
            }
            delta = self.finger_idx_to_compensator[finger_id].act_inference(policy_input)
            delta = torch.clamp(delta, -1.0 * torch.ones_like(delta), torch.ones_like(delta))
            compensated_actions[..., finger_joints] = (
                cur_ori_targets[..., finger_joints] + action_scale * delta
            )

        # ---- joint-level compensators ----
        for joint_id in self.joint_idx_to_compensator:
            joint_sel = self.joint_idx_to_compensator_input_joint_idxes[joint_id]
            qpos_hist = self.history_qpos_buf[-self.compensator_history_length:, joint_sel]
            qtars_hist = self.history_qtars_buf[-self.compensator_history_length:, joint_sel]

            if self.test_openloop_replay_test_compensator is not None:
                # Replace the live history with the recorded step of the current trajectory.
                traj = self.replay_test_compensator_traj_idx
                step = self.test_openloop_replay_test_compensator_step
                qpos_hist = self.test_openloop_replay_test_compensator_states[traj, step, joint_sel].unsqueeze(0)
                qtars_hist = self.test_openloop_replay_test_compensator_actions[traj, step, joint_sel].unsqueeze(0)

            normalized_qpos = self._unscale(
                qpos_hist,
                self.allegro_dof_lower[joint_sel],
                self.allegro_dof_upper[joint_sel],
            ).float()
            history = torch.cat([normalized_qpos, qtars_hist], dim=-1)
            history = history.contiguous().view(-1).contiguous().unsqueeze(0)

            if self.test_openloop_replay_test_compensator is None:
                target_now = cur_ori_targets[..., joint_sel]
                policy_obs = torch.cat([history, target_now], dim=-1)
            else:
                # The target is the recorded action of the following step.
                next_recorded = self.test_openloop_replay_test_compensator_actions[
                    self.replay_test_compensator_traj_idx,
                    self.test_openloop_replay_test_compensator_step + 1,
                    joint_sel,
                ].unsqueeze(0)
                policy_obs = torch.cat([history, next_recorded], dim=-1)
                # Advance the cursor; wrap to the next trajectory at the
                # second-to-last recorded step.
                last_usable_step = self.test_openloop_replay_test_compensator_states.shape[1] - 2
                if self.test_openloop_replay_test_compensator_step == last_usable_step:
                    self.test_openloop_replay_test_compensator_step = 0
                    self.replay_test_compensator_traj_idx += 1
                else:
                    self.test_openloop_replay_test_compensator_step += 1

            policy_input = {
                'obs': self.joint_idx_to_compensator_running_mean_std[joint_id](policy_obs).float()
            }
            delta = self.joint_idx_to_compensator[joint_id].act_inference(policy_input)
            delta = torch.clamp(delta, -1.0 * torch.ones_like(delta), torch.ones_like(delta))
            compensated_actions[..., joint_sel] = (
                cur_ori_targets[..., joint_sel] + action_scale * delta
            )

        return compensated_actions
        
        
    
    def _obs_leap2hora(self, obses):
        """Convert a raw LEAP joint reading to the policy's joint convention.

        Every value is shifted by ``-pi`` and the joints are then permuted
        with ``self.joint_idxes_inversed_ordering``.

        Unlike the earlier implementation, the shift is computed on a new
        array, so the caller's ``obses`` is left untouched.

        Args:
            obses: Array-like of joint readings, indexed by joint along its
                first axis.

        Returns:
            A new ``numpy.ndarray`` with the shifted, reordered readings.
        """
        # Out-of-place subtraction: never write back into the sensor buffer.
        shifted = np.asarray(obses) - np.pi
        return shifted[self.joint_idxes_inversed_ordering]
    
    def _action_hora2leap(self, actions):
        """Permute ``actions`` with ``self.joint_idxes_ordering`` and return the result."""
        leap_order = self.joint_idxes_ordering
        return actions[leap_order]
    
    def _unscale(self, x, lower, upper):
        """Map ``x`` from the range ``[lower, upper]`` to ``[-1, 1]``.

        ``lower`` maps to -1 and ``upper`` to +1; values outside the range are
        extrapolated, not clipped.
        """
        centered_twice = 2.0 * x - upper - lower
        span = upper - lower
        return centered_twice / span
    
    def _scale(self, x, lower, upper):
        return (x * (upper - lower) + upper + lower) / 2.0
        # return (2.0 * x - upper - lower) / (upper - lower)
        
    def parse_quality_selection_button(self, ):
        """Refresh ``X_pressed`` / ``Y_pressed`` from the TF button transform.

        Looks up the ``pressed_left_button`` transform relative to ``world``
        (latest available, 0.5 s timeout). A translation ``x`` equal to 1 sets
        ``X_pressed``; a translation ``y`` equal to 1 sets ``Y_pressed``.
        """
        button_tf = self.tf_buffer.lookup_transform("world", f"pressed_left_button", rospy.Time(0), rospy.Duration(0.5))
        translation = button_tf.transform.translation

        self.X_pressed = bool(translation.x == 1)
        self.Y_pressed = bool(translation.y == 1)
        
        
    def _init_tf_listener(self, ):
        """Start the ROS node named ``'example'`` and attach a TF2 buffer/listener.

        Stores the buffer as ``self.tf_buffer`` and the listener as
        ``self.listener``.
        """
        rospy.init_node('example')
        buffer = tf2_ros.Buffer()
        self.tf_buffer = buffer
        # Keep a reference to the listener on the instance alongside the buffer.
        self.listener = tf2_ros.TransformListener(buffer)
    
    
    def reset_to_initial_state(self, init_grasp_idx):
        """Drive the hand to a stored grasp and re-seed all history buffers.

        The first 16 entries of ``self.grasp_data[init_grasp_idx]`` are
        permuted with ``joint_idxes_ordering`` and sent to the hand. After a
        0.5 s wait the measured joint positions are read back, converted with
        ``_obs_leap2hora`` and used to fill every row of the history buffers.
        The remaining entries of the grasp row are stored as
        ``self.init_obj_pose``.

        Args:
            init_grasp_idx: Row index into ``self.grasp_data``.
        """
        hand_targets = self.grasp_data[init_grasp_idx, :16][self.joint_idxes_ordering].copy()
        self.leap_hand.set_allegro(hand_targets)
        time.sleep(0.5)  # wait before reading the joint positions back

        measured = self._obs_leap2hora(self.leap_hand.read_pos())
        obses = torch.from_numpy(measured.astype(np.float32)).cuda()

        def tiled(values, buf):
            # One copy of `values` per row of `buf`.
            return values.unsqueeze(0).repeat(buf.size(0), 1)

        # The [-1, 1]-mapped reading is only needed for normalised input or
        # for the extrinsics history.
        if self.invdyn_pred_extrin or self.normalize_input:
            cur_obs_buf = self._unscale(obses.clone(), self.allegro_dof_lower, self.allegro_dof_upper)
        else:
            cur_obs_buf = None

        qpos_seed = cur_obs_buf if self.normalize_input else obses
        self.history_qpos_buf[:, :] = tiled(qpos_seed, self.history_qpos_buf)
        self.history_qtars_buf[:, :] = tiled(obses, self.history_qtars_buf)

        if self.invdyn_pred_extrin:
            self.extrin_history_qpos_buf[:, :] = tiled(cur_obs_buf, self.extrin_history_qpos_buf)
            self.extrin_history_qtars_buf[:, :] = tiled(obses, self.extrin_history_qtars_buf)

        # Roll-out logs each start with their own copy of the measured pose.
        self.rolling_out_qpos = [obses.clone()]
        self.rolling_out_qtars = [obses.clone()]
        self.rolling_out_pred_sim_states = [obses.clone()]
        self.rolling_out_pred_real_wm_states_wo_compensator = [obses.clone()]
        self.init_obj_pose = self.grasp_data[init_grasp_idx, 16:].copy()
        
        
    def get_real_wm_pred_output(self, cur_nex_input_target):
        """Predict next joint states with the per-joint real world models.

        For each joint index in ``self.joint_idx_to_real_wm`` the model receives
        that joint's normalized qpos history (``'state'``) and normalized target
        history ending with the new target (``'action'``); its output is mapped
        back to the joint range with ``_scale``.

        NOTE(review): when ``self.normalize_input`` is set, other code in this
        file stores already-normalized values in ``history_qpos_buf``, and this
        method applies ``_unscale`` to them again -- confirm that is intended.

        Args:
            cur_nex_input_target: Next joint target; assumed to be shaped
                ``(1, nn_dofs)`` since it is concatenated with history rows
                along dim 0 -- TODO confirm.

        Returns:
            The per-joint predictions concatenated along dim 0, in the
            iteration order of ``self.joint_idx_to_real_wm``.
        """
        hist_len = self.wm_invdyn_history_length
        lower, upper = self.allegro_dof_lower, self.allegro_dof_upper

        # Last ``hist_len`` observed joint positions, normalized.
        qpos_window = self.history_qpos_buf[-hist_len:]
        norm_qpos = self._unscale(qpos_window, lower, upper).float()

        # Target window: the ``hist_len - 1`` latest past targets plus the new
        # one. ``hist_len == 1`` is special-cased because ``[-1 + 1:]`` would
        # select the whole buffer instead of nothing.
        if hist_len == 1:
            qtars_window = cur_nex_input_target.clone()
        else:
            past_qtars = self.history_qtars_buf[-hist_len + 1:]
            qtars_window = torch.cat([past_qtars, cur_nex_input_target], dim=0)
        norm_qtars = self._unscale(qtars_window, lower, upper).float()

        per_joint_preds = []
        for joint_idx, joint_wm in self.joint_idx_to_real_wm.items():
            wm_input = {
                'state': norm_qpos[..., joint_idx].unsqueeze(0),
                'action': norm_qtars[..., joint_idx].unsqueeze(0),
            }
            raw_pred = joint_wm(wm_input)
            # Back to joint-range units using this joint's own limits.
            scaled_pred = self._scale(raw_pred, lower[[joint_idx]], upper[[joint_idx]]).float()
            per_joint_preds.append(scaled_pred)

        return torch.cat(per_joint_preds, dim=0)
        
    
    def deploy(self):
        
        if self.hand_type == 'allegro':
            self.deploy_allegro()
            exit(0)
        
        import rospy
        from hora.algo.deploy.robots.leap import LeapNode
        # try to set up rospy
        rospy.init_node('example')
        
        print(f"Registering leap node")
        self.leap_hand = LeapNode()
        print(f"Registered leap node")
        
        
        # Wait for connections.
        rospy.sleep(0.5)
        
        cur_ts_idx = 0

        hz = 20
        ros_rate = rospy.Rate(hz)

        # # command to the initial position
        # for t in range(hz * 4):
        #     tprint(f'setup {t} / {hz * 4}')
        #     allegro.command_joint_position(self.init_pose)
        #     obses, _ = allegro.poll_joint_position(wait=True)
        #     ros_rate.sleep()
        
        # # command to the initial position
        self.leap_hand.set_default_pds()
        
        grasp_cache_name = 'cache/leap_down_init0d375_grasp_50k_s08.npy'
        
        
        if self.preset_grasp_data_fn is None:
            grasp_data = np.load(grasp_cache_name).astype(np.float32)
        else:
            grasp_data = np.load(self.preset_grasp_data_fn).astype(np.float32)
        
        
        self.grasp_data = grasp_data
        sampled_idx = np.random.randint(0, grasp_data.shape[0])
        sampled_idx = 10 # very good !
        sampled_idx = 20
        # sampled_idx = 90
        
        if self.test_openloop_replay:
            data_collect_st_time = time.time()
            last_data_collect_time = time.time()
            tot_data_collect_times = []
            cur_targets_to_set = self.replay_actions[self.test_exp_idx, 0, : 16].detach().cpu().numpy()[self.joint_idxes_ordering].copy()
            cur_ts_idx += 1 
        else:
            cur_targets_to_set = grasp_data[sampled_idx, :16][self.joint_idxes_ordering].copy()
        cur_obj_pose = grasp_data[sampled_idx, 16: ].copy()
        cur_obj_pose = torch.from_numpy(cur_obj_pose).float().to(self.device)
        
        
        
        if self.replay_wave_type == 'sine':
            cur_targets_to_set = np.zeros_like(cur_targets_to_set)
            self.sampled_omega_magnitude = torch.rand(2)
            omega_coef, magnitude_coef = self.sampled_omega_magnitude[0].item(), self.sampled_omega_magnitude[1].item() # two values #
            sampled_omega = self.sine_omega_lower + (self.sine_omega_upper - self.sine_omega_lower) * omega_coef
            sampled_magnitude = self.sine_magnitude_lower + (self.sine_magnitude_upper - self.sine_magnitude_lower) * magnitude_coef
            # magnitude * torch.sin(2 * np.pi * sampled_oemga * (cur_step / float(100))) #
            #  
            self.replay_joint_idx = torch.randint(0, len(self.to_replay_single_joint_idxes), ()).item() # get the replay joint index #
            
            self.replay_joint_idx = self.to_replay_single_joint_idxes[self.replay_joint_idx]
            
            cur_targets_to_set = torch.from_numpy(cur_targets_to_set).float().to(self.device)
            cur_targets_to_set = self._scale(cur_targets_to_set, self.allegro_dof_lower, self.allegro_dof_upper)
            cur_targets_to_set = cur_targets_to_set.detach().cpu().numpy()[self.joint_idxes_ordering].copy()
        
        self.init_obj_pose = grasp_data[sampled_idx, 16: ].copy()
        
        self.last_time = time.time()
        self.leap_hand.set_allegro(cur_targets_to_set)
        time.sleep(0.5)
        self.leap_hand.set_customized_dps()
        
        # initial leap hand pose observation
        real_leap_pos = self.leap_hand.read_pos()

        obses = self._obs_leap2hora(real_leap_pos)

        # initialize allegro #
        # obses, _ = allegro.poll_joint_position(wait=True)
        # obses = _obs_allegro2hora(obses)
        
        obses = torch.from_numpy(obses.astype(np.float32)).cuda()

        
        if self.invdyn_pred_extrin:
            cur_obs_buf = self._unscale(obses.clone(), self.allegro_dof_lower, self.allegro_dof_upper)
            cur_qtar_buf = cur_obs_buf.clone()
            
            if self.normalize_input:
                self.history_qpos_buf[:, :] = cur_obs_buf.unsqueeze(0).repeat(self.history_qpos_buf.size(0), 1)
            else:
                self.history_qpos_buf[:, :] = obses.unsqueeze(0).repeat(self.history_qpos_buf.size(0), 1)
            self.history_qtars_buf[:, :] = obses.unsqueeze(0).repeat(self.history_qtars_buf.size(0), 1)
            
            self.extrin_history_qpos_buf[:, :] = cur_obs_buf.unsqueeze(0).repeat(self.extrin_history_qpos_buf.size(0), 1)
            self.extrin_history_qtars_buf[:, :] = obses.unsqueeze(0).repeat(self.extrin_history_qtars_buf.size(0), 1)
            
            ###### history obj pose buf ######
            self.history_obj_pose_buf[:, :] = cur_obj_pose.unsqueeze(0).repeat(self.history_obj_pose_buf.size(0), 1).contiguous()
        else:
            if self.normalize_input:
                cur_obs_buf = self._unscale(obses.clone(), self.allegro_dof_lower, self.allegro_dof_upper)
                self.history_qpos_buf[:, :] = cur_obs_buf.unsqueeze(0).repeat(self.history_qpos_buf.size(0), 1)
            else:
                self.history_qpos_buf[:, :] = obses.unsqueeze(0).repeat(self.history_qpos_buf.size(0), 1)
            self.history_qtars_buf[:, :] = obses.unsqueeze(0).repeat(self.history_qtars_buf.size(0), 1)
            
            ###### history obj pose buf ######
            self.history_obj_pose_buf[:, :] = cur_obj_pose.unsqueeze(0).repeat(self.history_obj_pose_buf.size(0), 1).contiguous()


        if self.hist_context_length > 0:
            # hist_context_qpos_buf
            self.hist_context_qpos_buf[:, :] = obses.contiguous().unsqueeze(0).repeat(self.hist_context_length, 1).contiguous()
            self.hist_context_qtars_buf[:, :] = obses.contiguous().unsqueeze(0).repeat(self.hist_context_length, 1).contiguous()


        self.rolling_out_qpos = [ obses.clone() ]
        self.rolling_out_qtars = [ obses.clone() ]
        self.rolling_out_pred_sim_states = [ obses.clone() ]
        self.rolling_out_pred_real_wm_states_wo_compensator = [ obses.clone() ]
        
        experience_idx = 0
        bad_experiences_idx =  0

        while True:
            flatten_history_qpos = self.history_qpos_buf.contiguous().view(-1).contiguous()
            flatten_history_qtars = self.history_qtars_buf.contiguous().view(-1).contiguous()
            
            ### add the control err ###
            flatten_history_control_err = (self.history_qpos_buf - self.history_qtars_buf).contiguous().view(-1).contiguous()
            ### add history object pose ###
            flatten_histroy_obj_pose = self.history_obj_pose_buf[..., 3:].contiguous().view(-1).contiguous()
            flatten_obj_state_perdictor_input = torch.cat(
                [ flatten_history_qpos, flatten_history_qtars, flatten_history_control_err, flatten_histroy_obj_pose ], dim=-1
            )
            if self.obj_state_predictor:
                pred_obj_state = self.invdynmodel_obj_state_predictor.forward_states_for_actions(
                    flatten_obj_state_perdictor_input.unsqueeze(0),
                )
                print(f"pred_obj_state: {pred_obj_state.size()}")
                pred_obj_state = pred_obj_state[0, ]
                pred_obj_state = pred_obj_state.contiguous().view(self.invdyn_config_objstate_predictor.invdyn.future_length, -1).contiguous()
                pred_obj_state = pred_obj_state[0]
                self.history_obj_pose_buf[:-1, :] = self.history_obj_pose_buf[1:, :].clone()
                self.history_obj_pose_buf[-1, 3:] = pred_obj_state.detach().clone()
                print(pred_obj_state)
            ### add the control err ###
            
            flatten_future_obj_motion = self.future_obj_motion_buf.contiguous().view(-1).contiguous()
            history_input = torch.cat([flatten_history_qpos, flatten_history_qtars], dim=0)
            
            if self.w_hand_root_ornt:
                history_input = torch.cat([history_input, self.hand_root_tsr], dim=-1)
            
            history_input = history_input.unsqueeze(0)
            future_input = flatten_future_obj_motion.unsqueeze(0)
            
            if self.mask_obj_motion:
                future_input  = future_input * 0.0
                
            if self.invdyn_pred_extrin:
                history_extrin = torch.cat(
                    [ self.extrin_history_qpos_buf, self.extrin_history_qtars_buf ], dim=-1
                )
                history_extrin = history_extrin.unsqueeze(0)
            else:
                history_extrin = None
                
            hist_context = None
            if self.hist_context_length > 0:
                unscaled_context_qpos = self._unscale(self.hist_context_qpos_buf, self.allegro_dof_lower, self.allegro_dof_upper).float()
                hist_context = torch.cat(
                    [ unscaled_context_qpos, self.hist_context_qtars_buf ], dim=-1
                ).unsqueeze(0)
            
            
            target = self.invdynmodel.forward_states_for_actions(history_input, future_input, history_extrin=history_extrin, hist_context=hist_context)
            target = target[..., :16]
            
            if self.tune_bc_via_compensator_model:
                bc_compensator_input = torch.cat(
                    [ history_input, target ], dim=-1
                )
                bc_compensator_input = self.bc_compensator_running_mean_std(bc_compensator_input)
                bc_compensator_output = self.bc_compensator_model(bc_compensator_input)
                # the output of the compensator mode #
                bc_compensator_output = self.bc_compensator_out_mus(bc_compensator_output)
                bc_compensator_scale = 1./256
                target = target + bc_compensator_output * bc_compensator_scale
            
            target = torch.clip(target, self.allegro_dof_lower, self.allegro_dof_upper)
            
            
            if self.invdyn_pred_extrin:
                if self.normalize_output:
                    target = self._scale(target, self.allegro_dof_lower, self.allegro_dof_upper)
            
            if self.test_openloop_replay:
                target = self.replay_actions[self.test_exp_idx, cur_ts_idx, :16].unsqueeze(0)
                if self.openloop_replay_add_noise:
                    if np.random.uniform(0, 1) < 0.5:
                        target = target + torch.randn_like(target) * self.openloop_replay_noise_scale
                        # np.random.normal(0, self.openloop_replay_noise_scale, commands.shape)
                if self.replay_wave_type == 'sine':
                    cur_sampled_joint_pos = sampled_magnitude * np.sin(2 * sampled_omega * (float(cur_ts_idx) / float(100)) * np.pi)
                    target = torch.zeros_like(target)

                    target[0, self.replay_joint_idx] = cur_sampled_joint_pos # get the sampledj oint pose
                    target = self._scale(target, self.allegro_dof_lower, self.allegro_dof_upper).float()
                    # pass
            
            
            if self.use_delta_action_model and (not self.use_delta_action_model_policy):
                original_target = target.clone()
                
                ### multi action compensator is not used currently ###
                if self.multi_action_compensator_w_full_hand:
                    cur_qpos_buf = self.history_qpos_buf[-self.compensator_history_length: , :].clone()
                    cur_qpos_buf = self._unscale(cur_qpos_buf, self.allegro_dof_lower, self.allegro_dof_upper).float()
                    if self.compensator_history_length == 1:
                        cur_qtars_buf = target.clone()
                    else:
                        cur_qtars_buf = self.history_qtars_buf[-self.compensator_history_length + 1:, :].clone()
                        cur_qtars_buf = torch.cat([cur_qtars_buf, target], dim=0)
                    cur_qtars_buf = self._unscale(cur_qtars_buf, self.allegro_dof_lower, self.allegro_dof_upper).float()
                    
                    cur_qpos_buf = cur_qpos_buf.contiguous().view(1, -1).contiguous()
                    cur_qtars_buf = cur_qtars_buf.contiguous().view(1, -1).contiguous()
                    delta_action_input_dict = { 'state': cur_qpos_buf, 'action': cur_qtars_buf }
                    
                    compenator_selector_input = torch.cat(
                        [cur_qpos_buf, cur_qtars_buf], dim=-1
                    )
                    selector_output = self.compensator_weight_mlp(compenator_selector_input)
                    selector_output = torch.softmax(selector_output, dim=-1)
                    print(f"selector_output: {selector_output}")
                    tot_compensated_action = []
                    for compensator_idx in self.compensator_idx_to_compensator:
                        cur_output = self.compensator_idx_to_compensator[compensator_idx](delta_action_input_dict)
                        cur_output = torch.clamp(cur_output, -1.0, 1.0)
                        tot_compensated_action.append(cur_output)
                    tot_compensated_action = torch.stack(tot_compensated_action, dim=1)
                    tot_compensated_action = tot_compensated_action * selector_output.unsqueeze(-1)
                    tot_compensated_action = tot_compensated_action.sum(dim=1)
                    pred_delta_action = tot_compensated_action * self.action_compensator_scale
                    pred_delta_action = pred_delta_action * self.action_compensator_joint_mask.unsqueeze(0)
                    action_w_delta_action = original_target + pred_delta_action
                    
                elif self.action_compensator_w_full_hand:
                    
                    if self.test_openloop_replay:
                        replay_state_ts_idxes = [ _ for _ in range(cur_ts_idx - self.compensator_history_length, cur_ts_idx)  ]
                        replay_state_ts_idxes = [ max(_, 0) for _ in replay_state_ts_idxes ]
                        replay_action_ts_idxes = [ _ for _ in range(cur_ts_idx - self.compensator_history_length + 1,  cur_ts_idx + 1)  ]
                        replay_action_ts_idxes = [ max(_, 0) for _ in replay_action_ts_idxes ]
                        cur_qpos_buf = self.replay_states[self.test_exp_idx, replay_state_ts_idxes, :].clone()
                        cur_qtars_buf = self.replay_actions[self.test_exp_idx, replay_action_ts_idxes, :].clone()
                        cur_qpos_buf = self._unscale(cur_qpos_buf, self.allegro_dof_lower, self.allegro_dof_upper).float()
                        cur_qtars_buf = self._unscale(cur_qtars_buf, self.allegro_dof_lower, self.allegro_dof_upper).float()
                    else:
                        cur_qpos_buf = self.history_qpos_buf[-self.compensator_history_length: , :].clone()
                        cur_qpos_buf = self._unscale(cur_qpos_buf, self.allegro_dof_lower, self.allegro_dof_upper).float()
                        if self.compensator_history_length == 1:
                            cur_qtars_buf = target.clone()
                        else:
                            cur_qtars_buf = self.history_qtars_buf[-self.compensator_history_length + 1:, :].clone()
                            cur_qtars_buf = torch.cat([cur_qtars_buf, target], dim=0) # qtars buf
                        cur_qtars_buf = self._unscale(cur_qtars_buf, self.allegro_dof_lower, self.allegro_dof_upper).float()
                    
                    cur_qpos_buf = cur_qpos_buf.contiguous().view(1, -1).contiguous()
                    cur_qtars_buf = cur_qtars_buf.contiguous().view(1, -1).contiguous()
                    delta_action_input_dict = { 'state': cur_qpos_buf, 'action': cur_qtars_buf }
                    cur_out_delta_action = self.delta_action_model_full_hand(delta_action_input_dict)
                    cur_out_delta_action = torch.clamp(cur_out_delta_action, -1.0, 1.0)
                    
                    pred_delta_action = cur_out_delta_action * self.action_compensator_scale
                    
                    pred_delta_action = pred_delta_action * self.action_compensator_joint_mask.unsqueeze(0)
                    action_w_delta_action = original_target + pred_delta_action
                
                else:
                    if self.wm_history_length == 1:
                        cur_state = self.history_qpos_buf[-1, :].clone().unsqueeze(0)
                        cur_target = target.clone() # .unsqueeze(0)
                        
                        unscaled_state = self._unscale(cur_state, self.allegro_dof_lower, self.allegro_dof_upper).float()
                        unscaled_target = self._unscale(cur_target, self.allegro_dof_lower, self.allegro_dof_upper).float()
                    elif self.wm_history_length > 1:
                        cur_state  = self.history_qpos_buf[-self.wm_history_length:, :].clone()
                        cur_hist_target = self.history_qtars_buf[-self.wm_history_length + 1: , :].clone()
                        cur_target = torch.cat(
                            [ cur_hist_target, target.clone() ], dim=0
                        )
                        unscaled_state = self._unscale(cur_state, self.allegro_dof_lower, self.allegro_dof_upper).float()
                        unscaled_target = self._unscale(cur_target, self.allegro_dof_lower, self.allegro_dof_upper).float()
                        # unscaled_state = unscaled_state.contiguous().view( -1).contiguous().unsqueeze(0)
                        # unscaled_target = unscaled_target.contiguous().view( -1).contiguous().unsqueeze(0)
                    input_dict_delta_action = {
                        'state': unscaled_state, 'action': unscaled_target
                    }
                    
                    if self.joint_idx >= 0:
                        if self.delta_action_model_using_abs_target:
                            action_w_delta_action = target.clone()
                        else:   
                            action_w_delta_action = unscaled_target[-1:].clone()
                        for joint_idx in self.joint_idx_to_delta_action_model:
                            finger_joint_idxes = [joint_idx]
                            finger_joint_idxes = torch.tensor(finger_joint_idxes).long().to(self.device)
                            if self.wm_history_length == 1:
                                input_dict_delta_action = {
                                    'state': unscaled_state[:, finger_joint_idxes], 'action': unscaled_target[:, finger_joint_idxes]
                                }
                            elif self.wm_history_length > 1:
                                input_dict_delta_action = {
                                    'state': unscaled_state[:, finger_joint_idxes].contiguous().view(-1).unsqueeze(0), 'action': unscaled_target[:, finger_joint_idxes].contiguous().view(-1).unsqueeze(0)
                                }
                            if self.hist_context_length > 0:
                                
                                input_dict_delta_action.update( {
                                    'hist_state': self.history_qpos_buf[-self.hist_context_length:, :][..., self.hist_context_finger_joint_idxes].contiguous().view(-1).unsqueeze(0), 'hist_action': self.history_qtars_buf[-self.hist_context_length : , :][..., self.hist_context_finger_joint_idxes].contiguous().view(-1).unsqueeze(0) }
                                )
                            pred_delta_action_joint = self.joint_idx_to_delta_action_model[joint_idx](input_dict_delta_action)
                            pred_delta_action_joint = torch.clamp(pred_delta_action_joint, -1.0, 1.0)
                            action_w_delta_action[:, finger_joint_idxes] = action_w_delta_action[:, finger_joint_idxes] + pred_delta_action_joint[:] * self.action_compensator_scale
                        if self.delta_action_model_using_abs_target:
                            pred_delta_action = action_w_delta_action  - target
                        else:
                            pred_delta_action = action_w_delta_action - unscaled_target
                        
                    elif self.finger_idx >= 0:
                        action_w_delta_action = unscaled_target[-1: ].clone()
                        # finger_idx_to_delta_action_model
                        for finger_idx in self.finger_idx_to_delta_action_model:
                            finger_joint_idxes = [_ for _ in range(4 * finger_idx, 4 * (finger_idx + 1))]
                            finger_joint_idxes = torch.tensor(finger_joint_idxes).long().to(self.device)
                            if self.wm_history_length == 1:
                                input_dict_delta_action = {
                                    'state': unscaled_state[:, finger_joint_idxes], 'action': unscaled_target[:, finger_joint_idxes]
                                }
                            elif self.wm_history_length > 1:
                                input_dict_delta_action = {
                                    'state': unscaled_state[:, finger_joint_idxes].contiguous().view(-1).unsqueeze(0), 'action': unscaled_target[:, finger_joint_idxes].contiguous().view(-1).unsqueeze(0)
                                }
                            if self.hist_context_length > 0:
                                input_dict_delta_action.update(
                                    { 'hist_state': self.history_qpos_buf[-self.hist_context_length:, :][..., self.hist_context_finger_joint_idxes].contiguous().view(-1).unsqueeze(0), 'hist_action': self.history_qtars_buf[-self.hist_context_length + 1: , :][..., self.hist_context_finger_joint_idxes].contiguous().view(-1).unsqueeze(0) }
                                )
                            pred_delta_action_finger = self.finger_idx_to_delta_action_model[finger_idx](input_dict_delta_action)
                            action_w_delta_action[:, finger_joint_idxes] = action_w_delta_action[:, finger_joint_idxes] + pred_delta_action_finger[:]
                        pred_delta_action = action_w_delta_action - unscaled_target
                    else:
                        pred_delta_action = self.world_model_delta_actions(input_dict_delta_action) # 
                        # print(f"pred_delta_action: {pred_delta_action}")
                        action_w_delta_action = unscaled_target + pred_delta_action
                
                
                acc_act_max_val = 0.05
                # acc_act_max_val = 0.2
                if torch.max(torch.abs(pred_delta_action[0])).item() <= acc_act_max_val:
                    # print(f"pred_delta_action: {pred_delta_action}")
                    if self.delta_action_model_using_abs_target:
                        target = action_w_delta_action
                    else:
                        target = self._scale(action_w_delta_action, self.allegro_dof_lower, self.allegro_dof_upper).float()
                    target = torch.clip(target, self.allegro_dof_lower, self.allegro_dof_upper)
            
                delta_act_abs = torch.abs(pred_delta_action)
                self.compensated_res_abs_running_mean_std(delta_act_abs)
                if self.test_openloop_replay:
                    pass
                else:
                    print(f"running mean of compensated delta action: {self.compensated_res_abs_running_mean_std.running_mean}") 
            
            
            
            if self.normalize_output:
                commands = self._scale(target, self.allegro_dof_lower, self.allegro_dof_upper)
                commands = commands.detach().cpu().numpy()[0]
            else:
                commands = target.detach().cpu().numpy()[0]
                
            
            if self.test_with_sim_wm_for_prediction:
                future_ref = original_target.clone() if self.use_delta_action_model else target.clone()
                pred_nex_state = self.invdyn_world_model.forward_states_for_actions(history_input.float(), future_ref.float(), history_extrin=None)
                
            
            if self.test_openloop_replay:
                transformed_states = self.replay_states[self.test_exp_idx, cur_ts_idx, :16].unsqueeze(0).detach().cpu().numpy()[0]
                transformed_states = self._action_hora2leap(transformed_states)
                commands = self._action_hora2leap(commands)
                # if self.openloop_replay_add_noise:
                #     if np.random.uniform(0, 1) < 0.5:
                #         commands = commands + np.random.normal(0, self.openloop_replay_noise_scale, commands.shape)
                commands = self.openloop_replay_interp_coef * transformed_states + (1.0 - self.openloop_replay_interp_coef) * commands
            else:
                # commands = commands[self.joint_idxes_ordering]
                commands = self._action_hora2leap(commands)
            
            
            cur_time = time.time()
            # print(f"delta time: {cur_time - self.last_time}")
            self.last_time = cur_time
            
            self.leap_hand.set_allegro(commands)
            
            
            # commands = _action_hora2allegro(commands)
            # allegro.command_joint_position(commands)
            ros_rate.sleep()  # keep 20 Hz command
            # get o_{t+1}
            
            # obses, torques = allegro.poll_joint_position(wait=True)
            # obses = _obs_allegro2hora(obses)
            
            obses = self.leap_hand.read_pos()
            obses = self._obs_leap2hora(obses)
            
            obses = torch.from_numpy(obses.astype(np.float32)).cuda()
            
            
            
            if self.use_delta_action_model and self.test_with_real_wm_for_prediction:
                real_wm_pred_output = self.get_real_wm_pred_output(original_target)
                print(f"real wm pred output: {real_wm_pred_output}")
                pred_obses_wo_compensator = obses.clone() # 
                pred_obses_wo_compensator[self.real_wm_joint_idxes] = real_wm_pred_output.clone() # predicted nex state #
                self.rolling_out_pred_real_wm_states_wo_compensator.append(pred_obses_wo_compensator.clone())

            if self.rolling_out_policy or self.test_openloop_replay:
                self.rolling_out_qpos.append(obses.clone())
                self.rolling_out_qtars.append(target[0].clone())
                if self.test_with_sim_wm_for_prediction:
                    self.rolling_out_pred_sim_states.append(pred_nex_state[0].clone())
            
            
            if self.use_delta_action_model:
                target = original_target.clone() 

            
            
            if self.invdyn_pred_extrin:
                cur_obs_buf = self._unscale(obses.clone(), self.allegro_dof_lower, self.allegro_dof_upper)
                cur_qtar_buf = self._unscale(target.clone(), self.allegro_dof_lower, self.allegro_dof_upper)
            
                self.history_qpos_buf[:-1, :] = self.history_qpos_buf[1:, :].clone()
                
                if self.normalize_input:
                    self.history_qpos_buf[-1, :] = cur_obs_buf.detach().clone()
                else:
                    self.history_qpos_buf[-1, :] = obses.detach().clone()
                self.history_qtars_buf[:-1, :] = self.history_qtars_buf[1:, :].clone()
                self.history_qtars_buf[-1, :] = target.detach().clone()
                
                
                self.extrin_history_qpos_buf[:-1, :] = self.extrin_history_qpos_buf[1:, :].clone()
                self.extrin_history_qpos_buf[-1, :] = cur_obs_buf.detach().clone()
                self.extrin_history_qtars_buf[:-1, :] = self.extrin_history_qtars_buf[1:, :].clone()
                self.extrin_history_qtars_buf[-1, :] = target.detach().clone()
            else:
                self.history_qpos_buf[:-1, :] = self.history_qpos_buf[1:, :].clone()
                
                if self.normalize_input:
                    cur_obs_buf = self._unscale(obses.clone(), self.allegro_dof_lower, self.allegro_dof_upper)
                    self.history_qpos_buf[-1, :] = cur_obs_buf.detach().clone()
                else:
                    self.history_qpos_buf[-1, :] = obses.detach().clone()
                
                self.history_qtars_buf[:-1, :] = self.history_qtars_buf[1:, :].clone()
                self.history_qtars_buf[-1, :] = target.detach().clone()

                if self.hist_context_length > 0: # update buffer #
                    self.hist_context_qpos_buf[:-1, :] = self.hist_context_qpos_buf[1:, :].clone()
                    self.hist_context_qpos_buf[-1, :] = obses.detach().clone()
                    self.hist_context_qtars_buf[:-1, :] = self.hist_context_qtars_buf[1:, :].clone()
                    self.hist_context_qtars_buf[-1, :] = target.detach().clone()



            if self.test_openloop_replay:
                cur_ts_idx += 1
                
                if cur_ts_idx >= self.replay_actions.shape[1]:
                    
                    tot_rolling_out_qpos = torch.stack(self.rolling_out_qpos, dim=0)
                    tot_rolling_out_qtars = torch.stack(self.rolling_out_qtars, dim=0)
                    

                    tot_rolling_out_dict = {
                                            'qpos': tot_rolling_out_qpos.detach().cpu().numpy(), 'qtars': tot_rolling_out_qtars.detach().cpu().numpy(),
                                            }
                    if self.test_with_sim_wm_for_prediction:
                        cat_rolling_out_pred_next_states = torch.stack(self.rolling_out_pred_sim_states, dim=0)
                        tot_rolling_out_dict.update(
                            { 'pred_next_states': cat_rolling_out_pred_next_states.detach().cpu().numpy() }
                        )
                    if self.use_delta_action_model and self.test_with_real_wm_for_prediction:   
                        cat_rolling_out_pred_real_wm_states_wo_compensator = torch.stack(self.rolling_out_pred_real_wm_states_wo_compensator, dim=0)
                        tot_rolling_out_dict.update(
                            { 'pred_real_wm_states_wo_compensator': cat_rolling_out_pred_real_wm_states_wo_compensator.detach().cpu().numpy() }
                        )
                    
                    
                    if self.use_delta_action_model:
                        
                        sv_openloop_replay_folder = self.openloop_replay_sv_folder
                    else:
                        
                        sv_openloop_replay_folder = self.openloop_replay_sv_folder
                    
                    os.makedirs(sv_openloop_replay_folder, exist_ok=True)
                    sv_openloop_replay_name = f"openloop_replay_{self.test_exp_idx}.npy"
                    sv_openloop_replay_path = os.path.join(sv_openloop_replay_folder, sv_openloop_replay_name)
                    
                    tot_data_collect_times.append(time.time() - last_data_collect_time)
                    last_data_collect_time = time.time()
                    
                    print(f"Saving openloop replay data to {sv_openloop_replay_path}, Time passed: {time.time( ) - data_collect_st_time}")
                    np.save(sv_openloop_replay_path, tot_rolling_out_dict)
                    

                    
                    cur_ts_idx = 0
                    self.test_exp_idx += 1
                    
                    
                    if self.test_exp_idx - self.init_test_exp_idx >= 500 or self.test_exp_idx >= self.replay_actions.shape[0]:
                        
                        avg_data_collect_time = np.mean(tot_data_collect_times)
                        std_data_collect_time = np.std(tot_data_collect_times)
                        print(f"Avg data collect time: {avg_data_collect_time}, Std data collect time: {std_data_collect_time}")
                        
                        break
                    
                    cur_targets_to_set = self.replay_actions[self.test_exp_idx, cur_ts_idx, : 16].detach().cpu().numpy()[self.joint_idxes_ordering].copy()
                    
                    if self.replay_wave_type == 'sine':
                        cur_targets_to_set = np.zeros_like(cur_targets_to_set) # [self.joint_idxes_ordering].copy() # zero out the target joint pos # zero out the target joint pose
                        self.sampled_omega_magnitude = torch.rand(2) # .to(cur_targets_to_set.device)
                        omega_coef, magnitude_coef = self.sampled_omega_magnitude[0].item(), self.sampled_omega_magnitude[1].item() # two values #
                        sampled_omega = self.sine_omega_lower + (self.sine_omega_upper - self.sine_omega_lower) * omega_coef
                        sampled_magnitude = self.sine_magnitude_lower + (self.sine_magnitude_upper - self.sine_magnitude_lower) * magnitude_coef
                        # magnitude * torch.sin(2 * np.pi * sampled_oemga * (cur_step / float(100))) #
                        
                        self.replay_joint_idx = torch.randint(0, len(self.to_replay_single_joint_idxes), ()).item() # get the replay joint index #
                        self.replay_joint_idx = self.to_replay_single_joint_idxes[self.replay_joint_idx]
                        
                        cur_targets_to_set = torch.from_numpy(cur_targets_to_set).float().to(self.device)
                        cur_targets_to_set = self._scale(cur_targets_to_set, self.allegro_dof_lower, self.allegro_dof_upper)
                        cur_targets_to_set = cur_targets_to_set.detach().cpu().numpy()[self.joint_idxes_ordering].copy()
                    
                    
                    self.leap_hand.set_allegro(cur_targets_to_set)
                    time.sleep(0.5)
                    
                    real_leap_pos = self.leap_hand.read_pos()
                    obses = self._obs_leap2hora(real_leap_pos)
                    
                    obses = torch.from_numpy(obses.astype(np.float32)).cuda()
                    
                    self.rolling_out_qpos = [obses.clone()]
                    self.rolling_out_qtars = [obses.clone()]
                    self.rolling_out_pred_sim_states = [obses.clone()]
                    self.rolling_out_pred_real_wm_states_wo_compensator = [obses.clone()]
                    
                    cur_ts_idx += 1
            
            elif self.rolling_out_policy:
                cur_ts_idx += 1
                if cur_ts_idx >= 401:
                    while (not self.X_pressed) and (not self.Y_pressed):
                        self.parse_quality_selection_button()
                        # ros_rate.sleep()
                    
                    cat_rolling_out_qpos = torch.stack(self.rolling_out_qpos, dim=0)
                    cat_rolling_out_qtars = torch.stack(self.rolling_out_qtars, dim=0)
                    cat_rolling_out_dict = { # from the grasping pose data #
                        'qpos': cat_rolling_out_qpos.detach().cpu().numpy(), 'qtars': cat_rolling_out_qtars.detach().cpu().numpy(), 
                        'init_obj_pose': self.init_obj_pose # so I think that we shouldsave init obj pose and that should be set to the object in the model's resetting process #
                    }
                    
                    if self.test_with_sim_wm_for_prediction:
                        cat_rolling_out_pred_next_states = torch.stack(self.rolling_out_pred_sim_states, dim=0)
                        cat_rolling_out_dict.update(
                            { 'pred_next_states': cat_rolling_out_pred_next_states.detach().cpu().numpy() }
                        )
                        
                    if self.use_delta_action_model and self.test_with_real_wm_for_prediction:
                        cat_rolling_out_pred_real_wm_states_wo_compensator = torch.stack(self.rolling_out_pred_real_wm_states_wo_compensator, dim=0)
                        cat_rolling_out_dict.update(
                                { 'pred_real_wm_states_wo_compensator': cat_rolling_out_pred_real_wm_states_wo_compensator.detach().cpu().numpy() }
                        )
                        diff_qpos_w_sim_qpos_wo_compensator = torch.norm(
                            ( cat_rolling_out_pred_real_wm_states_wo_compensator[..., self.real_wm_joint_idxes] - cat_rolling_out_pred_next_states[..., self.real_wm_joint_idxes] ), p=2, dim=-1
                        )
                        diff_qpos_w_sim_qpos_wo_compensator = diff_qpos_w_sim_qpos_wo_compensator.mean()
                        
                        diff_qpos_w_sim_qpos_w_compensator = torch.norm(
                            ( cat_rolling_out_qpos[..., self.real_wm_joint_idxes] - cat_rolling_out_pred_next_states[..., self.real_wm_joint_idxes] ), p=2,  dim=-1
                        )
                        diff_qpos_w_sim_qpos_w_compensator = diff_qpos_w_sim_qpos_w_compensator.mean()
                        print(f"diff_qpos_w_sim_qpos_wo_compensator: {diff_qpos_w_sim_qpos_wo_compensator.item()}, diff_qpos_w_sim_qpos_w_compensator: {diff_qpos_w_sim_qpos_w_compensator.item()}")
                    
                    if self.X_pressed:
                        print(f"good experiences")
                        
                        cat_rolling_out_dict_sv_folder = self.cat_rolling_out_dict_sv_folder #  "./cache/good_experiences_cuboid_thin_scl0d5"
                        os.makedirs(cat_rolling_out_dict_sv_folder, exist_ok=True)
                        cat_rolling_out_dict_sv_name = f"good_experiences_{experience_idx}.npy"
                        cat_rolling_out_dict_sv_path = os.path.join(cat_rolling_out_dict_sv_folder, cat_rolling_out_dict_sv_name)
                        np.save(cat_rolling_out_dict_sv_path, cat_rolling_out_dict)
                        experience_idx += 1
                        print(f"Good experiences saved to {cat_rolling_out_dict_sv_path}")

                    else:
                        cat_rolling_out_dict_sv_folder_bad  = self.cat_rolling_out_dict_sv_folder_bad # 
                        os.makedirs(cat_rolling_out_dict_sv_folder_bad, exist_ok=True)
                        cat_rolling_out_dict_sv_name = f"bad_experiences_{bad_experiences_idx}.npy"
                        cat_rolling_out_dict_sv_path = os.path.join(cat_rolling_out_dict_sv_folder_bad, cat_rolling_out_dict_sv_name)
                        np.save(cat_rolling_out_dict_sv_path, cat_rolling_out_dict)
                        bad_experiences_idx += 1
                        print(f"bad experiences saved to {cat_rolling_out_dict_sv_path}")
                    
                    
                    self.X_pressed = self.Y_pressed = False
                    cur_ts_idx = 1
                    sampled_idx =( sampled_idx + 1) % self.grasp_data.shape[0]
                    self.reset_to_initial_state(sampled_idx)
            
            
            if rospy.is_shutdown():
                exit(0)


    def restore(self, fn):
        """Load a checkpoint and restore the normalizers and the policy model.

        Args:
            fn: Checkpoint location handed to ``torch.load``. The loaded
                object is indexed with the keys ``'running_mean_std'``,
                ``'model'`` and ``'sa_mean_std'``; each entry is passed to the
                ``load_state_dict`` of the attribute with the same name.

        Raises:
            KeyError: If the loaded checkpoint is a dict that lacks one of the
                three keys above.
        """
        # Remap stored tensors onto this object's device so that a checkpoint
        # written on a different GPU index (e.g. cuda:1) still deserializes on
        # a machine that lacks that device. When no ``device`` attribute is
        # set, ``None`` keeps torch.load's default placement.
        # NOTE(review): torch.load unpickles the file -- only load checkpoints
        # from a trusted source.
        target_device = getattr(self, 'device', None)
        checkpoint = torch.load(fn, map_location=target_device)

        self.running_mean_std.load_state_dict(checkpoint['running_mean_std'])
        self.model.load_state_dict(checkpoint['model'])
        self.sa_mean_std.load_state_dict(checkpoint['sa_mean_std'])


